{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting MNIST_split.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile MNIST_split.py\n",
    "\n",
    "from __future__ import print_function, division\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "import numpy as np\n",
    "import torchvision\n",
    "from torchvision import datasets, models, transforms\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "import copy\n",
    "import os\n",
    "import sys\n",
    "\n",
    "\n",
    "import shutil\n",
    "import pdb\n",
    "\n",
    "import torch.utils.data as data\n",
    "from PIL import Image\n",
    "import os\n",
    "import os.path\n",
    "import errno\n",
    "import torch\n",
    "import codecs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Appending to MNIST_split.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile -a MNIST_split.py\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "class MNIST_Split(datasets.MNIST):\n",
    "    \"\"\"MNIST restricted to a chosen subset of digits.\n",
    "\n",
    "    Samples whose original label appears in ``digits`` are kept; every kept\n",
    "    sample is relabelled with the position of its digit inside ``digits``\n",
    "    (e.g. ``digits=[3, 7]`` maps the digit 3 to target 0 and 7 to target 1).\n",
    "    The selected images and remapped labels are stored in ``self.digit_data``\n",
    "    and ``self.digit_labels``, and ``self.classes`` is set to ``digits``.\n",
    "\n",
    "    Args:\n",
    "        root (string): Root directory of the dataset, forwarded to\n",
    "            ``datasets.MNIST``.\n",
    "        train (bool, optional): Forwarded to ``datasets.MNIST``; selects whether\n",
    "            the train or the test tensors are filtered.\n",
    "        transform (callable, optional): Applied to the PIL image in\n",
    "            ``__getitem__``. E.g, ``transforms.ToTensor()``\n",
    "        target_transform (callable, optional): Applied to the remapped target\n",
    "            in ``__getitem__``.\n",
    "        download (bool, optional): Forwarded to ``datasets.MNIST``.\n",
    "        digits (sequence of int, optional): Digits to keep, in the order that\n",
    "            defines the new labels. Defaults to ``[1, 2]``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``digits`` is empty.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, root, train=True, transform=None, target_transform=None, download=False, digits=None):\n",
    "        # ``None`` stands in for the former ``[1, 2]`` default: a list default is\n",
    "        # shared by every call, and this list is also stored on ``self.classes``.\n",
    "        digits = [1, 2] if digits is None else list(digits)\n",
    "        if not digits:\n",
    "            # Fail here instead of leaving digit_data/digit_labels as None and\n",
    "            # crashing later inside __len__ / __getitem__.\n",
    "            raise ValueError('digits must contain at least one digit')\n",
    "\n",
    "        super(MNIST_Split, self).__init__(root, train, transform, target_transform, download)\n",
    "\n",
    "        self.classes = digits\n",
    "        self.digit_data = None\n",
    "        self.digit_labels = None\n",
    "\n",
    "        # Pick the tensors of the requested split once, so the extraction loop\n",
    "        # below is shared by the train and test cases.\n",
    "        if self.train:\n",
    "            images, labels = self.train_data, self.train_labels\n",
    "        else:\n",
    "            images, labels = self.test_data, self.test_labels\n",
    "\n",
    "        # Loop over the given digits and extract their corresponding data.\n",
    "        data_parts = []\n",
    "        label_parts = []\n",
    "        for digit in digits:\n",
    "            digit_mask = torch.eq(labels, digit)\n",
    "            digit_index = torch.nonzero(digit_mask).view(-1)\n",
    "            # Relabel with the digit's position inside ``digits``.\n",
    "            remapped_labels = labels[digit_mask]\n",
    "            remapped_labels.fill_(digits.index(digit))\n",
    "            data_parts.append(images[digit_index])\n",
    "            label_parts.append(remapped_labels)\n",
    "\n",
    "        # Concatenate once instead of growing the tensors pair by pair.\n",
    "        self.digit_data = torch.cat(data_parts, 0)\n",
    "        self.digit_labels = torch.cat(label_parts, 0)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            index (int): Index into the filtered samples.\n",
    "\n",
    "        Returns:\n",
    "            tuple: (image, target) where target is the position of the sample's\n",
    "            digit inside ``digits``. A tensor image is reshaped with\n",
    "            ``view(-1, 28*28)``; any other image type is returned as produced.\n",
    "        \"\"\"\n",
    "        img, target = self.digit_data[index], self.digit_labels[index]\n",
    "\n",
    "        # doing this so that it is consistent with all other datasets\n",
    "        # to return a PIL Image\n",
    "        img = Image.fromarray(img.numpy(), mode='L')\n",
    "\n",
    "        if self.transform is not None:\n",
    "            img = self.transform(img)\n",
    "\n",
    "        if self.target_transform is not None:\n",
    "            target = self.target_transform(target)\n",
    "\n",
    "        # Without a tensor-producing transform ``img`` is still a PIL image,\n",
    "        # which has no ``view``; only flatten when a tensor came back.\n",
    "        if torch.is_tensor(img):\n",
    "            img = img.view(-1, 28*28)\n",
    "\n",
    "        return img, target\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of samples kept after filtering by ``digits``.\"\"\"\n",
    "        return self.digit_labels.size(0)\n",
    "\n",
    "  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
